import numpy as np


def load_data():
    lines = [['my', 'dog', 'has', 'flea', 'problems', 'help', 'please'],
             ['maybe', 'not', 'take', 'him', 'to', 'dog', 'park', 'stupid'],
             ['my', 'dalmation', 'is', 'so', 'cute', 'I', 'love', 'him'],
             ['stop', 'posting', 'stupid', 'worthless', 'garbage'],
             ['mr', 'licks', 'ate', 'my', 'steak', 'how', 'to', 'stop', 'him'],
             ['quit', 'buying', 'worthless', 'dog', 'food', 'stupid']]
    # 1是违规言论,0合法言论
    y = np.array([0, 1, 0, 1, 0, 1])
    # 所有不重复单词的集合
    words = []
    for line in lines:
        # extend() 函数用于在列表末尾一次性追加另一个序列中的多个值（用新列表扩展原来的列表）
        words.extend(line)
    words = list(set(words))
    # 句子数字化,出现了某单词就是1,否则是0 one-hot编码
    x = np.zeros((len(lines), len(words)))
    for i in range(len(lines)):
        for j in range(len(words)):
            if words[j] in lines[i]:
                x[i, j] = 1
    return x, y


def train(x, y):
    # 求总体的违规率,等于 违规次数 / 总次数
    p1 = y.sum() / len(y)
    p0 = 1 - p1
    # 取对数概率
    p1 = np.log(p1)
    p0 = np.log(p0)
    # 根据y的值,切分x为正例和反例
    x_1 = x[y == 1]
    x_0 = x[y == 0]
    # 统计在正例中,所有词出现的次数,一个句子中出现多次也只算1次
    # 也就是说,是每个词出现的句子数
    # 最后要加1是为了避免0的情况,也就是说,所有词最少出现1次
    p1_given_word = x_1.sum(axis=0) + 1
    # 上面统计的是次数,除以总次数就等于概率了
    # 也就是每个词,出现在正例句子中的概率
    p1_given_word = p1_given_word / p1_given_word.sum()
    # 取对数概率
    p1_given_word = np.log(p1_given_word)
    # p0_given_word的计算同理
    p0_given_word = x_0.sum(axis=0) + 1
    p0_given_word = p0_given_word / p0_given_word.sum()
    p0_given_word = np.log(p0_given_word)
    return p1_given_word, p0_given_word, p1, p0


def pred(x, p1_given_word, p0_given_word, p1, p0):
    # 本来p1_given_x应该是每个单词属于p1的概率连乘, 但是因为前面取了对数, 所以这里求和就可以了
    p1_given_x = x.dot(p1_given_word) + p1
    p0_given_x = x.dot(p0_given_word) + p0
    return 1 if p1_given_x > p0_given_x else 0


x, y = load_data()
p1_given_word, p0_given_word, p1, p0 = train(x, y)

currect = 0
for xi, yi in zip(x, y):
    if pred(xi, p1_given_word, p0_given_word, p1, p0) == yi:
        currect += 1
print(currect / len(x))
